/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

// Generate tokens in a loop.
#pragma once

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include <executorch/extension/llm/runner/stats.h>
#include <executorch/extension/llm/runner/text_decoder_runner.h>
#include <executorch/extension/tensor/tensor.h>
#include <pytorch/tokenizers/tokenizer.h>

namespace executorch {
namespace extension {
namespace llm {

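/**
 * Runs the post-prefill decode loop: repeatedly steps the text decoder,
 * samples the next token, and streams each decoded text piece to a callback
 * until an EOS token is sampled, max_new_tokens is reached, or stop() is
 * requested.
 *
 * A minimal usage sketch, assuming a tokenizer, a TextDecoderRunner, a Stats
 * object, and the surrounding prefill logic already exist elsewhere; the
 * variable names below are illustrative only:
 *
 *   auto eos_ids = std::make_unique<std::unordered_set<uint64_t>>(
 *       std::unordered_set<uint64_t>{tokenizer->eos_tok()});
 *   TextTokenGenerator generator(
 *       tokenizer, decoder_runner, true, std::move(eos_ids), &stats);
 *   ET_CHECK_MSG(
 *       generator.load() == ::executorch::runtime::Error::Ok, "load failed");
 *   auto num_tokens = generator.generate(
 *       {first_token}, num_prompt_tokens, max_new_tokens, 0.0f,
 *       [](const std::string& piece) { ::fputs(piece.c_str(), stdout); });
 */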
class ET_EXPERIMENTAL TextTokenGenerator {
 public:
  TextTokenGenerator(
      ::tokenizers::Tokenizer* tokenizer,
      TextDecoderRunner* text_decoder_runner,
      bool use_kv_cache,
      std::unique_ptr<std::unordered_set<uint64_t>>&& eos_ids,
      Stats* stats)
      : tokenizer_(tokenizer),
        text_decoder_runner_(text_decoder_runner),
        eos_ids_(std::move(eos_ids)),
        use_kv_cache_(use_kv_cache),
        stats_(stats) {}

  virtual ~TextTokenGenerator() = default;

  /**
   * Token generation loop.
   * @param tokens If using a kv cache, the single token produced by prefill;
   * otherwise, the prompt tokens followed by the first token produced by
   * prefill.
   * @param start_pos The start position of the new tokens, based on how many
   * prompt tokens were prefilled.
   * @param max_new_tokens Maximum number of new tokens to generate.
   * @param temperature Controls the randomness of predictions by scaling the
   * logits before applying softmax. A higher temperature produces more random
   * predictions; a lower temperature produces more deterministic ones.
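   * Concretely, assuming the sampler follows the standard convention, the
   * next-token distribution is softmax(logits / temperature), and a
   * temperature of 0 falls back to greedy argmax decoding.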
   * @param token_callback Invoked with the decoded text piece after each
   * token is generated.
   * @return The number of tokens generated.
   */
  inline ::executorch::runtime::Result<int64_t> generate(
      std::vector<uint64_t> tokens,
      int64_t start_pos,
      int32_t max_new_tokens,
      float temperature = 0.0f,
      const std::function<void(const std::string&)>& token_callback = {}) {
    ET_CHECK_MSG(
        !tokens.empty(), "Token generation loop shouldn't take empty tokens");
    int64_t pos = start_pos; // position in the sequence

    std::vector<uint64_t> token_data; // allocate space for the tokens
    std::vector<executorch::aten::SizesType> token_shape;

    // Token after prefill
    uint64_t cur_token = tokens.back();
    uint64_t prev_token;

    if (use_kv_cache_) {
      // Hard-code these to size 1, as the kv cache is currently locked to a
      // static size.
      token_data = {cur_token};
      token_shape = {1, 1};
    } else {
      token_data = tokens;
      token_shape = {1, static_cast<int>(tokens.size())};
    }

    // Initialize the tensor wrapper. from_blob() creates a non-owning view
    // over token_data, so token_data must outlive tokens_managed.
    auto tokens_managed = from_blob(
        token_data.data(), token_shape, executorch::aten::ScalarType::Long);

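    // Reset the stop flag so a stop() request from a previous run does not
    // immediately terminate this one.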
    should_stop_ = false;

    // Generate our tokens
    while (pos < start_pos + max_new_tokens) {
      // Run the model
      auto logits_res = text_decoder_runner_->step(tokens_managed, pos);

      ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
      executorch::aten::Tensor& logits_tensor = logits_res.get();

      prev_token = cur_token;

      stats_->on_sampling_begin();
      cur_token =
          text_decoder_runner_->logits_to_token(logits_tensor, temperature);
      stats_->on_sampling_end();

      pos++;

      if (use_kv_cache_) {
        // update the token tensor. token_data will not be empty.
        // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds)
        token_data[0] = cur_token;
      } else {
        // Without a kv cache, append the new token and resize the tensor so
        // the full sequence is fed to the model on the next step.
        token_data.push_back(cur_token);
        ET_CHECK_OK_OR_RETURN_ERROR(resize_tensor_ptr(
            tokens_managed, {1, static_cast<int>(token_data.size())}));
      }

      // Decode the new token into a text piece using the Tokenizer.
      auto decode_result = tokenizer_->decode(prev_token, cur_token);
      if (!decode_result.ok()) {
        ET_LOG(
            Error,
            "Tokenizers error code %d",
            static_cast<uint32_t>(decode_result.error()));
        return ::executorch::runtime::Error::InvalidArgument;
      }
      // Guard against an empty std::function; invoking one throws.
      if (token_callback) {
        token_callback(std::move(*decode_result));
      }

      if (should_stop_) {
        break;
      }

      // Data-dependent terminating condition: the sampled token is an EOS
      // token.
      if (eos_ids_->find(cur_token) != eos_ids_->end()) {
        printf("\n");
        ET_LOG(Info, "\nReached to the end of generation");
        break;
      }
    }
    return pos - start_pos;
  }

  /**
   * Request that the generation loop stop. The loop exits after the current
   * token has been decoded and delivered to the callback.
   */
  inline void stop() {
    should_stop_ = true;
  }

  /**
   * Load the necessary resources for TextTokenGenerator.
   * This method should be called before using the generate() method.
   */
  ::executorch::runtime::Error load() {
    return text_decoder_runner_->load();
  }

  /**
   * Check if the TextTokenGenerator has been successfully loaded.
   * @return True if the resources are loaded, false otherwise.
   */
  inline bool is_loaded() const {
    return tokenizer_->is_loaded() && text_decoder_runner_->is_method_loaded();
  }

 private:
  /**
   * Note: TextTokenGenerator does not own the tokenizer_ and
   * text_decoder_runner_. The lifecycle of these objects should be managed
   * externally, likely in the Runner. This class assumes that the provided
   * pointers remain valid for the duration of its use.
   */
  ::tokenizers::Tokenizer* tokenizer_;
  TextDecoderRunner* text_decoder_runner_;
  std::unique_ptr<std::unordered_set<uint64_t>> eos_ids_;
  bool use_kv_cache_;

  // Set by stop() to request that the generation loop exit.
  bool should_stop_ = false;

  // stats
  Stats* stats_;
};

} // namespace llm
} // namespace extension
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::extension::llm::TextTokenGenerator;
} // namespace executor
} // namespace torch
